#!/usr/bin/env python

import time
from math import sqrt, log10
import numpy as np
from astropy.time import Time
import matplotlib as mpl
import matplotlib.pyplot as plt
import click
import specio
import func


class Measure:
    """Measure continuum and emission-line fluxes of one spectrum.

    Three observed-frame wavelength windows are used: a left continuum
    window, a right continuum window and an emission-line window.  The
    continuum under the line is the straight line joining the median flux
    of each continuum window (placed at the window centre); the line flux
    is the continuum-subtracted flux summed over the line window, weighted
    by the pixel width.
    """

    def __init__(self, specname=None, windowfile=None, redshift=0.0):
        """
        Parameters
        ----------
        specname : str, optional
            Spectrum file loaded with ``specio.Spectrum``.  If omitted, call
            ``read_spec`` before measuring.
        windowfile : str, optional
            Window configuration file (format described in
            ``update_window``).  When given, the redshift written in the
            file is used and the ``redshift`` argument is ignored.
        redshift : float
            Redshift applied to the default rest-frame windows.
        """
        # Defined up front so show()/get_flux() fail predictably instead of
        # with a missing-attribute error when no spectrum was given.
        self.specname = None
        self.spectrum = None
        if specname is not None:
            self.read_spec(specname)
        self.update_window(windowfile, redshift)
        # Results of the most recent get_flux() call (None until it runs).
        self.contl = None
        self.contr = None
        self.lineflux = None
        self.contlerr = None
        self.contrerr = None
        self.lineerr = None

    def update_window(self, windowfile=None, redshift=0.0):
        """Set the observed-frame continuum and line windows.

        Without ``windowfile`` the default rest-frame windows
        [4740, 4790] (left continuum), [5075, 5125] (right continuum) and
        [4810, 4910] (line) are scaled by ``1 + redshift``.

        With ``windowfile`` the text file is read as follows: the first line
        is ignored, the first column of line 2 is the redshift, and the first
        two columns of lines 3, 4 and 5 are the rest-frame edges of the left
        continuum, right continuum and line windows.  The redshift from the
        file replaces the ``redshift`` argument.
        """
        self.redshift = redshift
        if windowfile is None:
            left, right, line = (4740, 4790), (5075, 5125), (4810, 4910)
        else:
            # Context manager so the file handle is closed promptly.
            with open(windowfile) as fh:
                lst = fh.readlines()

            def edges(row):
                # First two whitespace-separated columns of a window line.
                cols = row.split()
                return float(cols[0]), float(cols[1])

            self.redshift = float(lst[1].split()[0])
            left, right, line = edges(lst[2]), edges(lst[3]), edges(lst[4])
        z = 1 + self.redshift
        self.left_window = [left[0] * z, left[1] * z]
        self.right_window = [right[0] * z, right[1] * z]
        self.line_window = [line[0] * z, line[1] * z]

    def read_spec(self, specname):
        """Load ``specname`` with ``specio.Spectrum`` and remember its name."""
        self.specname = specname
        self.spectrum = specio.Spectrum(specname)

    def get_conterr(self, flux):
        """Return the standard deviation of ``flux`` divided by sqrt(N)."""
        conef = sqrt(len(flux))
        return np.std(flux) / conef

    def get_flux(self):
        """Measure both continuum levels and the integrated line flux.

        Also stores the uncertainties returned by ``get_err``.

        Returns
        -------
        tuple
            ``(contl, contr, lineflux)``: median flux in the left and right
            continuum windows and the continuum-subtracted line flux.
        """
        wave = self.spectrum.wave
        flux = self.spectrum.flux
        err = self.spectrum.err
        # Window edges are exclusive.
        argl = np.where((wave > self.left_window[0]) &
                        (wave < self.left_window[1]))
        argr = np.where((wave > self.right_window[0]) &
                        (wave < self.right_window[1]))
        argline = np.where((wave > self.line_window[0]) &
                           (wave < self.line_window[1]))
        conl = np.median(flux[argl])
        conr = np.median(flux[argr])
        # Each continuum level is anchored at the centre of its window.
        wl = np.mean(self.left_window)
        wr = np.mean(self.right_window)
        # Pixel width wave[i + 1] - wave[i].  np.diff is one element short,
        # so the last width is repeated for the final pixel; otherwise a line
        # window reaching the last pixel would raise IndexError.
        step = np.diff(wave)
        step = np.append(step, step[-1:])
        wline = wave[argline]
        stepline = step[argline]
        errline = err[argline]
        # Linear continuum between the two anchors; np.interp holds it
        # constant outside [wl, wr].
        conline = np.interp(wline, [wl, wr], [conl, conr])
        intflux = np.sum((flux[argline] - conline) * stepline)

        self.contl = conl
        self.contr = conr
        self.lineflux = intflux
        self.contlerr = self.get_conterr(flux[argl])
        self.contrerr = self.get_conterr(flux[argr])
        # Continuum contribution: both continuum errors in quadrature, scaled
        # by half the separation of the two window centres.
        hwidth = (wr - wl) * 0.5
        partialerr2 = (self.contrerr**2 + self.contlerr**2) * hwidth**2
        self.lineerr = (np.sum((errline * stepline)**2) + partialerr2)**0.5
        return self.contl, self.contr, self.lineflux

    def get_err(self):
        """Return ``(contlerr, contrerr, lineerr)``, measuring first if needed."""
        if self.lineerr is None:
            self.get_flux()
        return self.contlerr, self.contrerr, self.lineerr

    def get_jd(self):
        """Return the Julian date of the observation.

        Uses the header keyword ``jd`` when present; otherwise builds an ISO
        timestamp from ``DATE-OBS`` and ``UT`` and converts it with astropy.
        """
        header = self.spectrum.header
        try:
            return header['jd']
        except KeyError:
            pass
        # NOTE(review): assumes DATE-OBS holds only the date part
        # (YYYY-MM-DD) and UT the time of day — confirm for all instruments.
        sdt = header['DATE-OBS'] + 'T' + header['UT'].strip()
        return Time(sdt, format='isot').jd

    def show(self, fig=None):
        """Plot the spectrum with the measuring windows and fitted continuum.

        Parameters
        ----------
        fig : matplotlib.figure.Figure, optional
            Figure to draw into; it is cleared first.  Defaults to the
            current figure.
        """
        if self.contl is None:
            # Nothing measured yet: the continuum markers would be None.
            self.get_flux()
        if fig is None:
            fig = plt.gcf()
        fig.clf()
        ax = fig.add_subplot(111)
        ax.set_title(self.specname)
        wave = self.spectrum.wave
        flux = self.spectrum.flux
        err = self.spectrum.err
        ax.step(wave, flux, where='mid', color='b')
        ax.errorbar(wave, flux, yerr=err, fmt='.', capsize=1, color='b')
        for edge in list(self.left_window) + list(self.right_window):
            ax.axvline(edge, linestyle='--', color='r')
        for edge in self.line_window:
            ax.axvline(edge, linestyle='--', color='g')
        wl = np.mean(self.left_window)
        wr = np.mean(self.right_window)
        ax.plot([wl], [self.contl], marker='o', color='r', markersize=7)
        ax.plot([wr], [self.contr], marker='o', color='r', markersize=7)
        ax.plot([wl, wr], [self.contl, self.contr], color='g',
                linestyle=':', lw=2)
        # Zoom to the continuum anchors plus 80% of their separation on each
        # side, and fit y to the data (with error bars) inside that range.
        marginx = (wr - wl) * 0.8
        xlimt1 = wl - marginx
        xlimt2 = wr + marginx
        arg = np.where((wave > xlimt1) & (wave < xlimt2))
        maxy = np.max(flux[arg] + err[arg])
        miny = np.min(flux[arg] - err[arg])
        marginy = (maxy - miny) * 0.2
        ax.set_xlim(xlimt1, xlimt2)
        ax.set_ylim(miny - marginy, maxy + marginy)
        fig.canvas.draw()
        fig.canvas.flush_events()
        fig.show()


def get_pos(ax):
    """Return a data-coordinate point inside the current limits of *ax*.

    The point sits at 96.5 % of the x-range and 78 % of the y-range,
    measured from the first value returned by ``get_xlim``/``get_ylim``.

    Returns
    -------
    tuple
        ``(px, py)`` in data coordinates.
    """
    xlo, xhi = ax.get_xlim()
    ylo, yhi = ax.get_ylim()
    return xlo + 0.965 * (xhi - xlo), ylo + 0.78 * (yhi - ylo)


def plot_lc(jd, fc, fcerr, fl, flerr):
    """Plot continuum and emission-line light curves in two stacked panels.

    Parameters
    ----------
    jd : array_like
        Julian dates of the observations.
    fc, fcerr : array_like
        Continuum flux and its uncertainty (lower panel).
    fl, flerr : array_like
        Emission-line flux and its uncertainty (upper panel).

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.  It is not shown; the caller decides when
        to call ``plt.show()``.

    Notes
    -----
    Modifies global matplotlib rcParams (tick style, legend font size) and
    enables ``text.usetex``, so a working LaTeX installation is required.
    """
    jd = np.asarray(jd)
    fc = np.asarray(fc)
    fcerr = np.asarray(fcerr)
    fl = np.asarray(fl)
    flerr = np.asarray(flerr)

    mpl.rcParams.update({
        'xtick.direction': 'in',
        'ytick.direction': 'in',
        'xtick.major.size': 6,
        'xtick.major.width': 1.5,
        'ytick.major.size': 6,
        'ytick.major.width': 1.5,
        'xtick.minor.size': 3,
        'xtick.minor.width': 0.9,
        'ytick.minor.size': 3,
        'ytick.minor.width': 0.9,
        'legend.fontsize': 12,
    })

    fontsiz = 14
    ticklabelsize = 13

    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')
    fig = plt.figure(figsize=(10, 5))
    # Panel geometry in figure-fraction coordinates: both panels share the
    # left edge and width; the emission panel sits right above the continuum
    # panel, separated by a gap of ``sy``.
    x1, y1 = 0.08, 0.1
    w1, h1 = 0.55, 0.4
    sx, sy = 0.01, 0.01
    w2 = 0.335
    width = w1 + w2 + sx
    ax0 = fig.add_axes([x1, y1 + h1 + sy, width, h1])
    ax1 = fig.add_axes([x1, y1, width, h1], sharex=ax0)

    # The x axis shows JD minus a round offset (a multiple of 100 days).
    datefrom = int(np.min(jd) / 100) * 100
    xdata = jd - datefrom
    fcunit = func.get_unit(fc)
    flunit = func.get_unit(fl)

    # Lower panel: continuum.
    ax1.yaxis.set_ticks_position('both')
    ax1.xaxis.set_ticks_position('both')
    ax1.tick_params(which='major', labelsize=ticklabelsize)
    plt.setp(ax1.get_xticklabels(), visible=True)
    ax1.set_ylabel(r'Continuum ($\times 10^{%d}$)' % int(log10(fcunit)),
                   fontsize=fontsiz)
    ax1.minorticks_on()
    ax1.set_xlabel('JD - %d (days)' % datefrom, fontsize=fontsiz)
    ax1.errorbar(xdata, fc / fcunit, yerr=fcerr / fcunit, fmt='o',
                 color="C1", ms=5, linewidth=0.5)

    # Upper panel: emission line.
    ax0.yaxis.set_ticks_position('both')
    ax0.xaxis.set_ticks_position('both')
    ax0.tick_params(which='major', labelsize=ticklabelsize)
    ax0.errorbar(xdata, fl / flunit, yerr=flerr / flunit, fmt='o',
                 color='C0', ms=5, linewidth=0.5)
    ax0.set_ylabel(r'Emission ($\times 10^{%d}$)' % int(log10(flunit)),
                   fontsize=fontsiz)
    ax0.minorticks_on()

    # Top axis: a labelled tick at 1 January of every year after the year of
    # the first observation, up to and including the year of the last one.
    # A distinct loop name keeps the ``jd`` data array intact.
    yearfrom = Time(np.min(jd), format='jd').datetime.year + 1
    yearend = Time(np.max(jd), format='jd').datetime.year + 1
    years = range(yearfrom, yearend)
    xticklst = [Time('%d-01-01' % year).jd - datefrom for year in years]
    labellst = [str(year) for year in years]
    axt = ax0.twiny()
    plt.setp(ax0.get_xticklabels(), visible=False)
    axt.set_xticks(xticklst)
    axt.set_xticklabels(labellst)
    axt.minorticks_on()
    axt.tick_params(which='major', labelsize=ticklabelsize)
    # Keep the year axis aligned with the JD axis it annotates.
    xl, xu = ax0.get_xlim()
    axt.set_xlim(xl, xu)
    return fig


@click.command()
@click.argument('lstname')
@click.option('--show/--no-show', default=True,
              help='show the dynamical measuring window, default=show')
@click.option('-z', '--redshift', default=0.0, help='redshift, default = 0')
@click.option('--winfile', default='default',
              help=('the config file of continuum and emission line windows. '
                    'if seted, the --redshift option will not work. '
                    'Default windows are \n[4740, 4790]\n[5075, 5125]\n'
                    '[4810, 4910]'))
@click.option('--out', default='default',
              help=('out lc file name, default=lstname+".lc", the data of each '
                    'column are jd, continuum left, continuum left err, '
                    'continuum right, continuum right err, emission line, '
                    'emission line err'))
@click.option('--pause', is_flag=True, default=False, help='pause at every spectrum')
def main(lstname, show, redshift, winfile, out, pause):
    """
    Measure the continuum and emission line light curves from a set of spectra.

    LSTNAME is a text file with one spectrum file name per line; blank lines
    and lines starting with '#' are skipped.
    """
    with open(lstname) as fh:
        namelst = [row.strip() for row in fh]
    # Skip blanks as well as comments: a blank line (e.g. a trailing newline)
    # used to crash the '#' test with IndexError.
    namelst = [name for name in namelst if name and not name.startswith('#')]
    if not namelst:
        raise click.ClickException('no spectrum listed in %s' % lstname)

    plt.ion()
    windowfile = None if winfile == 'default' else winfile
    rows = []
    for name in namelst:
        meas = Measure(name, windowfile=windowfile, redshift=redshift)
        conl, conr, line = meas.get_flux()
        conlerr, conrerr, lineerr = meas.get_err()
        jd = meas.get_jd()
        rows.append([jd, conl, conlerr, conr, conrerr, line, lineerr])
        if show:
            meas.show()
        if pause:
            input('press any key to continue')
    data = np.array(rows)

    outname = lstname + '.lc' if out == 'default' else out
    np.savetxt(outname, data, fmt='%.4f  %.6e  %.6e  %.6e  %.6e  %.6e  %.6e',
               header='jd  cl  clerr  cr  crerr  line  lineerr')

    plt.ioff()
    # The light-curve plot uses the right-hand continuum window (columns 3-4)
    # as "the" continuum, together with the line flux (columns 5-6).
    plot_lc(data[:, 0], data[:, 3], data[:, 4], data[:, 5], data[:, 6])
    plt.show()

if __name__ == '__main__':
    # Run as a script: hand control to the click command defined by main().
    main()
